{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Question Answering with Control\n",
    "\n",
    "This model shows a form of question answering where statements and questions are supplied through a single 'visual input' and the replies are produced in a 'motor output' as discussed in the book. You will implement this by using the basal ganglia to store and retrieve information from working memory in response to visual input. More specifically, the basal ganglia decides what to do with the information in the visual channel based on its content (i.e. whether it is a statement or a question).\n"
   ]
  },
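  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The routing decision described above can be illustrated without any neurons. The following is a minimal sketch in plain Python, not the Nengo implementation: it assumes random unit vectors stand in for the semantic pointers, and the helper names `random_unit_vector`, `dot` and `select_action` are made up for this illustration. It scores the visual input against `STATEMENT` and `QUESTION` with a dot product and picks the larger score, which is roughly the job the basal ganglia performs in the model.\n",
    "\n",
    "```python\n",
    "import math\n",
    "import random\n",
    "\n",
    "\n",
    "def random_unit_vector(dim, rng):\n",
    "    # Random vector scaled to length 1 (a stand-in for a semantic pointer)\n",
    "    v = [rng.gauss(0.0, 1.0) for _ in range(dim)]\n",
    "    norm = math.sqrt(sum(x * x for x in v))\n",
    "    return [x / norm for x in v]\n",
    "\n",
    "\n",
    "def dot(a, b):\n",
    "    return sum(x * y for x, y in zip(a, b))\n",
    "\n",
    "\n",
    "def select_action(visual, statement, question):\n",
    "    # Utility of each action is the similarity of the input to its cue\n",
    "    utilities = {\n",
    "        'store': dot(visual, statement),   # route visual input into memory\n",
    "        'answer': dot(visual, question),   # route the unbound result to motor\n",
    "    }\n",
    "    return max(utilities, key=utilities.get)\n",
    "\n",
    "\n",
    "rng = random.Random(0)\n",
    "dim = 100\n",
    "STATEMENT = random_unit_vector(dim, rng)\n",
    "QUESTION = random_unit_vector(dim, rng)\n",
    "CONTENT = random_unit_vector(dim, rng)  # hypothetical payload such as RED*CIRCLE\n",
    "\n",
    "statement_input = [s + c for s, c in zip(STATEMENT, CONTENT)]\n",
    "question_input = [q + c for q, c in zip(QUESTION, CONTENT)]\n",
    "\n",
    "print(select_action(statement_input, STATEMENT, QUESTION))\n",
    "print(select_action(question_input, STATEMENT, QUESTION))\n",
    "```\n",
    "\n",
    "With vectors of this size, the first call is expected to select `store` and the second `answer`, because the matching cue contributes a dot product close to 1 while unrelated random vectors contribute values close to 0."
   ]
  },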
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Setup the environment\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "%matplotlib inline\n",
    "\n",
    "import nengo\n",
    "from nengo.spa import Vocabulary\n",
    "from nengo.spa import BasalGanglia\n",
    "from nengo.spa import Thalamus\n",
    "from nengo.dists import Uniform"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create the model\n",
    "\n",
    "This model has parameters as described in the book. Note that in Nengo 1.4, network arrays were used to construct this model for computational reasons as explained in the book. Nengo 2.0 has 'EnsembleArray' as an equivalent to network arrays which you will use in this model. \n",
    "\n",
    "When you run the model, it will start by binding `RED` and `CIRCLE` and then binding `BLUE` and `SQUARE` so the memory essentially has `RED * CIRCLE + BLUE * SQUARE`. This is stored in memory because the model is told that `RED * CIRCLE` is a STATEMENT (i.e. `RED * CIRCLE + STATEMENT` in the code) as is `BLUE * SQUARE`. Then it is presented with something like `QUESTION + RED` (i.e., \"What is red?\"). The basal ganglia then reroutes that input to be compared to what is in working memory and the result shows up in the motor channel."
   ]
  },
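  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The store-then-query sequence above can be checked numerically without building a network. The following is a rough sketch in plain Python, assuming random unit vectors for the pointers and a direct O(n^2) circular convolution; the helper names and the dimensionality of 256 are choices made for this illustration and are not taken from the model below. It builds `RED * CIRCLE + BLUE * SQUARE`, unbinds `BLUE`, and reports which pointer the result most resembles.\n",
    "\n",
    "```python\n",
    "import math\n",
    "import random\n",
    "\n",
    "\n",
    "def circular_convolution(a, b):\n",
    "    # Bind two vectors: out[k] = sum_j a[j] * b[(k - j) mod n]\n",
    "    n = len(a)\n",
    "    return [sum(a[j] * b[(k - j) % n] for j in range(n)) for k in range(n)]\n",
    "\n",
    "\n",
    "def approximate_inverse(a):\n",
    "    # Involution used for unbinding: keep a[0] and reverse the rest\n",
    "    return [a[0]] + a[:0:-1]\n",
    "\n",
    "\n",
    "def random_unit_vector(dim, rng):\n",
    "    v = [rng.gauss(0.0, 1.0) for _ in range(dim)]\n",
    "    norm = math.sqrt(sum(x * x for x in v))\n",
    "    return [x / norm for x in v]\n",
    "\n",
    "\n",
    "def cosine(a, b):\n",
    "    dot_ab = sum(x * y for x, y in zip(a, b))\n",
    "    norm_a = math.sqrt(sum(x * x for x in a))\n",
    "    norm_b = math.sqrt(sum(y * y for y in b))\n",
    "    return dot_ab / (norm_a * norm_b)\n",
    "\n",
    "\n",
    "rng = random.Random(15)\n",
    "dim = 256\n",
    "names = ['RED', 'BLUE', 'CIRCLE', 'SQUARE']\n",
    "pointers = {name: random_unit_vector(dim, rng) for name in names}\n",
    "\n",
    "# Memory after both statements: RED*CIRCLE + BLUE*SQUARE\n",
    "red_circle = circular_convolution(pointers['RED'], pointers['CIRCLE'])\n",
    "blue_square = circular_convolution(pointers['BLUE'], pointers['SQUARE'])\n",
    "memory = [x + y for x, y in zip(red_circle, blue_square)]\n",
    "\n",
    "# Question 'What is blue?': unbind BLUE from the memory contents\n",
    "answer = circular_convolution(approximate_inverse(pointers['BLUE']), memory)\n",
    "similarities = {name: cosine(answer, vec) for name, vec in pointers.items()}\n",
    "best = max(similarities, key=similarities.get)\n",
    "print(best)\n",
    "```\n",
    "\n",
    "The unbound vector is only a noisy copy of `SQUARE`, so its similarity to `SQUARE` is well below 1 but should still stand clearly above the similarity to the other three pointers."
   ]
  },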
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "dim = 100         # Number of dimensions \n",
    "N = 30            # Neurons per dimension\n",
    "N_conv = 70       # Number of neurons per dimension in bind/unbind populations\n",
    "N_mem = 50        # Number of neurons per dimension in memory population\n",
    "ZERO = [0] * dim  # Defining a zero vector having length equal to the number of dimensions\n",
    "\n",
    "# Creating the vocabulary\n",
    "rng = np.random.RandomState(15)\n",
    "vocab = Vocabulary(dimensions=dim, rng=rng, max_similarity=0.05)\n",
    "\n",
    "model = nengo.Network(label='Question Answering with Control', seed=15)\n",
    "with model:\n",
    "    # Ensembles\n",
    "    visual = nengo.networks.EnsembleArray(n_neurons=N, n_ensembles=dim, max_rates=Uniform(100,300), label='Visual')\n",
    "    channel = nengo.networks.EnsembleArray(n_neurons=N, n_ensembles=dim, label='Channel')\n",
    "    motor = nengo.networks.EnsembleArray(n_neurons=N, n_ensembles=dim, label='Motor')   \n",
    "    \n",
    "    # Creating a memory (integrator)\n",
    "    tau = 0.1\n",
    "    memory = nengo.networks.EnsembleArray(n_neurons=N, n_ensembles=dim, label='Memory')\n",
    "    nengo.Connection(memory.output, memory.input, synapse=tau)     \n",
    "    \n",
    "    # function for providing visual input\n",
    "    def visual_input(t):\n",
    "        if 0.1 < t < 0.3:\n",
    "            return vocab.parse('STATEMENT+RED*CIRCLE').v\n",
    "        elif 0.35 < t < 0.5:\n",
    "            return vocab.parse('STATEMENT+BLUE*SQUARE').v\n",
    "        elif 0.55 < t < 0.7:\n",
    "            return vocab.parse('QUESTION+BLUE').v\n",
    "        elif 0.75 < t < 0.9:\n",
    "            return vocab.parse('QUESTION+CIRCLE').v\n",
    "        else:\n",
    "            return ZERO\n",
    "        \n",
    "\n",
    "    # function for flipping the output of the thalamus\n",
    "    def xBiased(x):\n",
    "        return [1-x]\n",
    "     \n",
    "    # Providing input to the model\n",
    "    input = nengo.Node(output=visual_input, size_out=dim)\n",
    "    nengo.Connection(input, visual.input)\n",
    "    \n",
    "    nengo.Connection(visual.output, channel.input, synapse=0.02)\n",
    "    nengo.Connection(channel.output, memory.input)\n",
    "    \n",
    "    # Creating the unbind network\n",
    "    unbind = nengo.networks.CircularConvolution(n_neurons=N_conv, dimensions=dim, invert_a=True)\n",
    "    nengo.Connection(visual.output, unbind.A)\n",
    "    nengo.Connection(memory.output, unbind.B)\n",
    "    nengo.Connection(unbind.output, motor.input)\n",
    "    \n",
    "    # Creating the basal ganglia and the thalamus network\n",
    "    BG = nengo.networks.BasalGanglia(dimensions=2)  \n",
    "    thal = nengo.networks.Thalamus(dimensions=2)\n",
    "    nengo.Connection(BG.output, thal.input, synapse=0.01)\n",
    "    \n",
    "    # Defining the transforms for connecting the visual input to the BG\n",
    "    trans0 = np.matrix(vocab.parse('STATEMENT').v)\n",
    "    trans1 = np.matrix(vocab.parse('QUESTION').v)\n",
    "    nengo.Connection(visual.output, BG.input[0], transform=trans0)\n",
    "    nengo.Connection(visual.output, BG.input[1], transform=trans1) \n",
    "    \n",
    "    # Connecting thalamus output to the two gates gating the channel and the motor populations\n",
    "    passthrough = nengo.Ensemble(n_neurons=N, dimensions=2)\n",
    "    nengo.Connection(thal.output, passthrough)\n",
    " \n",
    "    gate0 = nengo.Ensemble(N, 1, label='Gate0')\n",
    "    nengo.Connection(passthrough[0], gate0, function=xBiased, synapse=0.01)    \n",
    "    gate1 = nengo.Ensemble(N, 1, label='Gate1')\n",
    "    nengo.Connection(passthrough[1], gate1, function=xBiased, synapse=0.01)\n",
    "    \n",
    "    for ensemble in channel.ea_ensembles:\n",
    "        nengo.Connection(gate0, ensemble.neurons, transform=[[-3]] * gate0.n_neurons)\n",
    "        \n",
    "    for ensemble in motor.ea_ensembles:\n",
    "        nengo.Connection(gate1, ensemble.neurons, transform=[[-3]] * gate1.n_neurons) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Add Probes to Collect Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "with model:\n",
    "    Visual_p = nengo.Probe(visual.output, synapse=0.03)\n",
    "    Motor_p = nengo.Probe(motor.output, synapse=0.03)\n",
    "    Memory_p = nengo.Probe(memory.output, synapse=0.03)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run The Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "with nengo.Simulator(model) as sim:  # Create the simulator\n",
    "    sim.run(1.2)                     # Run it for 1.2 seconds"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot The Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10, 10))\n",
    "\n",
    "plt.subplot(6, 1, 1)\n",
    "plt.plot(sim.trange(), nengo.spa.similarity(sim.data[Visual_p], vocab))\n",
    "plt.legend(vocab.keys, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.0, fontsize=9)\n",
    "plt.ylabel(\"Visual\");\n",
    "\n",
    "plt.subplot(6, 1, 2)\n",
    "plt.plot(sim.trange(), nengo.spa.similarity(sim.data[Memory_p], vocab))\n",
    "plt.legend(vocab.keys, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.0, fontsize=9)\n",
    "plt.ylabel(\"Memory\");\n",
    "\n",
    "plt.subplot(6, 1, 3)\n",
    "plt.plot(sim.trange(), nengo.spa.similarity(sim.data[Motor_p], vocab))\n",
    "plt.legend(vocab.keys, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.0, fontsize=9)\n",
    "plt.ylabel(\"Motor\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The graphs above show that when the input to the Visual system is a `STATEMENT`, there is no response from the Motor system and the input is stored in the Memory. However, when the input to the Visual system is a `QUESTION`, the Motor system responds with the appropriate answer. For instance, when the input to Visual system is `CIRCLE` the output from the motor system is `RED`."
   ]
  },
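  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The absence of a motor response during a statement comes from the two gates in the model: each gate computes `1 - x` from one thalamus output and, while active, inhibits its target population. The following is a minimal sketch of that idea in plain Python. It assumes a toy rectified-linear neuron rather than Nengo's spiking neurons, and `gate_value` and `gated_rate` are hypothetical helper names; only the `1 - x` flip and the `-3` inhibitory weight are taken from the model code above.\n",
    "\n",
    "```python\n",
    "def gate_value(thalamus_output):\n",
    "    # Same flip as xBiased in the model: a selected action (near 1) gives a gate near 0\n",
    "    return 1.0 - thalamus_output\n",
    "\n",
    "\n",
    "def gated_rate(drive, thalamus_output, inhibition=-3.0):\n",
    "    # Toy rectified-linear neuron: an active gate adds a negative current to the drive\n",
    "    current = drive + inhibition * gate_value(thalamus_output)\n",
    "    return max(0.0, current)\n",
    "\n",
    "\n",
    "# Action selected (1.0): the gate is off and the drive passes through.\n",
    "# Action not selected (0.0): the gate is on and the neuron is silenced.\n",
    "for thalamus_output in (1.0, 0.0):\n",
    "    print(thalamus_output, gated_rate(0.8, thalamus_output))\n",
    "```\n",
    "\n",
    "In the model, the same pattern is applied to every neuron of the `Channel` and `Motor` populations, so each population only represents its input while the corresponding action is selected."
   ]
  },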
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create the model using the `spa` package"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import nengo\n",
    "import nengo.spa as spa\n",
    "from nengo.spa import Vocabulary\n",
    "import numpy as np\n",
    "\n",
    "D = 32  # the dimensionality of the vectors\n",
    "rng = np.random.RandomState(15)\n",
    "vocab = Vocabulary(dimensions=D, rng=rng, max_similarity=0.1)\n",
    "\n",
    "# Adding semantic pointers to the vocabulary\n",
    "CIRCLE = vocab.parse('CIRCLE')\n",
    "BLUE = vocab.parse('BLUE')\n",
    "RED = vocab.parse('RED')\n",
    "SQUARE = vocab.parse('SQUARE')\n",
    "ZERO = vocab.add('ZERO', [0]*D)\n",
    "\n",
    "model = spa.SPA(label=\"Question Answering with Control\", vocabs=[vocab])\n",
    "with model:\n",
    "    model.visual = spa.State(D)\n",
    "    model.motor = spa.State(D)\n",
    "    model.memory = spa.State(D, feedback=1, feedback_synapse=0.1)\n",
    "\n",
    "    actions = spa.Actions(\n",
    "        'dot(visual, STATEMENT) --> memory=visual',\n",
    "        'dot(visual, QUESTION) --> motor = memory * ~visual'\n",
    "    )\n",
    " \n",
    "    model.bg = spa.BasalGanglia(actions)\n",
    "    model.thalamus = spa.Thalamus(model.bg)\n",
    "\n",
    "    # function for providing visual input\n",
    "    def visual_input(t):\n",
    "        if 0.1 < t < 0.3:\n",
    "            return 'STATEMENT+RED*CIRCLE'\n",
    "        elif 0.35 < t < 0.5:\n",
    "            return 'STATEMENT+BLUE*SQUARE'\n",
    "        elif 0.55 < t < 0.7:\n",
    "            return 'QUESTION+BLUE'\n",
    "        elif 0.75 < t < 0.9:\n",
    "            return 'QUESTION+CIRCLE'\n",
    "        else:\n",
    "            return 'ZERO'\n",
    "\n",
    "    # Inputs\n",
    "    model.input = spa.Input(visual=visual_input)   "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run the model in nengo_gui"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from nengo_gui.ipython import IPythonViz\n",
    "IPythonViz(model, \"ch5-question-control.py.cfg\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "Press the play button in the visualizer to run the simulation. You should see the \"semantic pointer cloud\" graphs as shown in the figure below.\n",
    "\n",
    "The visual graph shows the input represented by `visual`. When this input is a STATEMENT, there is no response shown in the motor graph and the input is stored in `memory` (shown in memory graph). However, when the input to the `visual` is a QUESTION, the motor graph shows the appropriate answer. For instance, when the input to `visual`  is QUESTION+BLUE (showin in the visual graphs), the output from `motor` is SQUARE."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from IPython.display import Image\n",
    "Image(filename='ch5-question-control-1.png')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from IPython.display import Image\n",
    "Image(filename='ch5-question-control-2.png')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.10"
  },
  "widgets": {
   "state": {},
   "version": "1.1.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
